{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GNN Explainer\n",
    "\n",
    "This notebook is designed to visualize the results of the GNN Explainer.\n",
    "\n",
    "Use it after training the model with `train.py` and running the explainer optimization (`explainer_main.py`).\n",
    "The main purpose is to visualize the trained mask by interactively tuning the threshold. In many scientific applications, the explanation size is unknown a priori. This tool helps the user visualize the selected subgraph for different threshold values and find the right size for a good explanation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ipywidgets import interact, interactive, fixed, interact_manual\n",
    "import ipywidgets as widgets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Configure the experiment you want to visualize. These values should match the configuration used for the explainer run:\n",
    "\n",
    "> TODO: Unify configuration of experiments in yaml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directory; joined with `expdir` below to locate the experiment's files.\n",
    "logdir = '../log/'\n",
    "# Experiment sub-directory inside `logdir`; the mask files (*.npy) are read from here.\n",
    "expdir = 'syn2_base_h20_o20_explain'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the produced masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Names of all entries in the experiment directory (bare names, not full paths).\n",
    "dirs = os.listdir(os.path.join(logdir, expdir))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "masked_adjsyn2_base_h20_o20_explain.npy\n",
      "masked_adj.npy\n"
     ]
    }
   ],
   "source": [
    "# Collect the mask files (*.npy) found in the experiment directory.\n",
    "# os.listdir returns entries in arbitrary order, so the names are sorted to keep\n",
    "# the indices used further down (masks[0], MASK_IDX) stable between runs.\n",
    "masks = sorted(name for name in dirs if name.endswith('.npy'))\n",
    "for name in masks:\n",
    "    print(name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Utility to save masks:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {},
   "outputs": [],
   "source": [
    "from networkx.readwrite import json_graph\n",
    "\n",
    "def save_mask(G, fname, fmt='json', suffix=''):\n",
    "    \"\"\"Save the graph `G` into the experiment directory.\n",
    "\n",
    "    The file is written to `<logdir>/<expdir>/<fname>-filt-<suffix>.<fmt>`.\n",
    "\n",
    "    Args:\n",
    "        G: networkx graph to save.\n",
    "        fname: base name of the output file.\n",
    "        fmt: output format, one of\n",
    "            'json' -- node-link representation of `G`,\n",
    "            'pdf'  -- the current matplotlib figure (`G` itself is not drawn here),\n",
    "            'npy'  -- adjacency matrix of `G`.\n",
    "        suffix: optional text inserted after '-filt-'.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `fmt` is not one of the supported formats.\n",
    "    \"\"\"\n",
    "    if fmt not in ('json', 'pdf', 'npy'):\n",
    "        # Fail loudly: an unknown format used to fall through and save nothing.\n",
    "        raise ValueError('Unsupported format {!r}; expected json, pdf or npy.'.format(fmt))\n",
    "    pth = os.path.join(logdir, expdir, fname + '-filt-' + suffix + '.' + fmt)\n",
    "    if fmt == 'json':\n",
    "        with open(pth, 'w') as f:\n",
    "            json.dump(json_graph.node_link_data(G), f)\n",
    "    elif fmt == 'pdf':\n",
    "        plt.savefig(pth)\n",
    "    else:\n",
    "        np.save(pth, nx.to_numpy_array(G))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plotting utilities:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_adjacency_full(mask, ax=None):\n",
    "    \"\"\"Load a mask file from the experiment directory, show it as an image and return it.\n",
    "\n",
    "    Args:\n",
    "        mask: file name of the mask inside `<logdir>/<expdir>`.\n",
    "        ax: matplotlib axes to draw on; when None a new figure is created.\n",
    "\n",
    "    Returns:\n",
    "        The array loaded from the mask file.\n",
    "    \"\"\"\n",
    "    mask_path = os.path.join(logdir, expdir, mask)\n",
    "    matrix = np.load(mask_path, allow_pickle=True)\n",
    "    if ax is not None:\n",
    "        ax.imshow(matrix)\n",
    "        return matrix\n",
    "    plt.figure()\n",
    "    plt.imshow(matrix)\n",
    "    return matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_adjacency_full(mask, ax=None):\n",
    "    \"\"\"Load a mask file from the experiment directory and return it without plotting.\n",
    "\n",
    "    Args:\n",
    "        mask: file name of the mask inside `<logdir>/<expdir>`.\n",
    "        ax: unused; accepted so the signature matches `show_adjacency_full`.\n",
    "\n",
    "    Returns:\n",
    "        The array loaded from the mask file.\n",
    "    \"\"\"\n",
    "    mask_path = os.path.join(logdir, expdir, mask)\n",
    "    return np.load(mask_path, allow_pickle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d87c030bb64e4c4b940f655300d411b2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(FloatSlider(value=0.5, description='thresh', max=1.5, min=-0.5), Output()), _dom_classes…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Mask loaded once from disk; the callback below never modifies it in place.\n",
    "filt_adj = read_adjacency_full(masks[0])\n",
    "\n",
    "@interact\n",
    "def filter_adj(thresh=0.5):\n",
    "    \"\"\"Return a copy of the mask in which entries below `thresh` are set to 0.\"\"\"\n",
    "    # Threshold a copy: zeroing the shared array in place would make the result\n",
    "    # depend on every previous slider position (lowering `thresh` could never\n",
    "    # bring removed entries back).\n",
    "    thresholded = filt_adj.copy()\n",
    "    thresholded[thresholded < thresh] = 0\n",
    "    return thresholded"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Weight-based threshold:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9dd0f2ae93a54b279e8c081e2062abc2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(FloatSlider(value=0.5, description='thresh', max=1.0, step=0.01), Output()), _dom_classe…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# EDIT THIS INDEX\n",
    "MASK_IDX = 0\n",
    "# EDIT THIS INDEX\n",
    "\n",
    "m = masks[MASK_IDX]\n",
    "adj = read_adjacency_full(m)\n",
    "\n",
    "\n",
    "@interact(thresh=widgets.FloatSlider(value=0.5, min=0.0, max=1.0, step=0.01))\n",
    "def plot_interactive(thresh=0.5):\n",
    "    \"\"\"Plot the full mask, the thresholded mask and the resulting subgraph.\n",
    "\n",
    "    Entries of the mask below `thresh` are set to 0, nodes left without any edge\n",
    "    are dropped, and the remaining graph is drawn. The filtered graph is also\n",
    "    written as JSON via `save_mask` on every call, i.e. on every slider change.\n",
    "\n",
    "    Args:\n",
    "        thresh: mask values strictly below this threshold are removed.\n",
    "    \"\"\"\n",
    "    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))\n",
    "    # Figure-level title: plt.title would target the last axes (ax3) and then be\n",
    "    # overwritten by that panel's own title below.\n",
    "    fig.suptitle(str(m))\n",
    "\n",
    "    # Full adjacency\n",
    "    ax1.set_title('Full Adjacency mask')\n",
    "    full_adj = show_adjacency_full(m, ax=ax1)\n",
    "\n",
    "    # Filtered adjacency: threshold a copy so the full mask stays intact and the\n",
    "    # file is only read once per call.\n",
    "    filt_adj = full_adj.copy()\n",
    "    filt_adj[filt_adj < thresh] = 0\n",
    "    ax2.set_title('Filtered Adjacency mask')\n",
    "    ax2.imshow(filt_adj)\n",
    "\n",
    "    # Plot subgraph\n",
    "    ax3.set_title(\"Subgraph\")\n",
    "    G_ = nx.from_numpy_array(full_adj)\n",
    "    G = nx.from_numpy_array(filt_adj)\n",
    "    # Drop nodes that lost all their edges to the threshold.\n",
    "    G.remove_nodes_from(list(nx.isolates(G)))\n",
    "    nx.draw(G, ax=ax3)\n",
    "    save_mask(G, fname=m, fmt='json')\n",
    "\n",
    "    print(\"Removed {} edges -- K = {} remain.\".format(G_.number_of_edges()-G.number_of_edges(), G.number_of_edges()))\n",
    "    print(\"Removed {} nodes -- K = {} remain.\".format(G_.number_of_nodes()-G.number_of_nodes(), G.number_of_nodes()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
